{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_circles, load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import torch\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import yellowbrick as yb\n",
    "import matplotlib\n",
    "import matplotlib.pylab as plt\n",
    "\n",
    "# dtype = torch.long\n",
    "# device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data & prepare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix every random source used in this notebook so that Restart & Run All\n",
    "# reproduces the same data, the same split and the same model initialisation.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Two noisy concentric circles: raw_X holds the 2-D points, raw_y the binary labels.\n",
    "raw_X, raw_y = make_circles(n_samples=1000, noise=0.1, random_state=SEED)\n",
    "\n",
    "# 75/25 train/test split\n",
    "orig_X_train, orig_X_test, orig_y_train, orig_y_test = train_test_split(\n",
    "    raw_X, raw_y, test_size=0.25, random_state=SEED\n",
    ")\n",
    "\n",
    "# Transform the training split into tensors.\n",
    "# Features are float; labels are long because CrossEntropyLoss takes integer class indices.\n",
    "X = torch.tensor(orig_X_train, dtype=torch.float)\n",
    "y = torch.tensor(orig_y_train, dtype=torch.long)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from yellowbrick.contrib.scatter import ScatterVisualizer\n",
    "\n",
    "# Scatter plot of the training split (2-D features together with their labels)\n",
    "visualizer = ScatterVisualizer()\n",
    "visualizer.fit(orig_X_train, orig_y_train)\n",
    "visualizer.poof()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic Neural Net\n",
    "\n",
    "Three things are needed for an optimization problem:\n",
    "1. Model\n",
    "2. Loss function\n",
    "3. Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "# Sequential model allows easy model experimentation\n",
    "model = nn.Sequential(\n",
    "          nn.Linear(2, 16),    # input dim 2. 16 neurons in first layer.\n",
    "          nn.ReLU(),           # ReLU activation\n",
    "          #nn.Dropout(p=0.2),  # Optional dropout\n",
    "          nn.Linear(16, 4),    # Linear from 16 neurons down to 4\n",
    "          nn.ReLU(),\n",
    "          nn.Linear(4, 2),     # 2 outputs (logits), one per class\n",
    "          nn.Softmax(dim=1)    # Softmax turns the logits into class probabilities\n",
    "        )\n",
    "\n",
    "# CrossEntropyLoss applies log-softmax internally and therefore expects raw logits.\n",
    "# Feeding it the Softmax output would apply softmax twice, so train on the model\n",
    "# without its last layer. The slice is a Sequential over the same layer objects,\n",
    "# so its parameters are shared and model(...) still returns probabilities afterwards.\n",
    "logits_net = model[:-1]\n",
    "\n",
    "# Loss function. CrossEntropy is valid for classification problems.\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "# Optimizer. Many to choose from. \n",
    "optimizer = torch.optim.Adam(params=model.parameters())\n",
    "\n",
    "# Optimizer iterations\n",
    "for step in range(1000):\n",
    "    # Clear the gradient at the start of each step.\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # Compute the forward pass (raw logits, no Softmax)\n",
    "    output = logits_net(X)\n",
    "    \n",
    "    # Compute the loss\n",
    "    loss = loss_fn(output, y)\n",
    "    \n",
    "    # Backprop to compute the gradients\n",
    "    loss.backward()\n",
    "    \n",
    "    # Update the model parameters\n",
    "    optimizer.step()\n",
    "\n",
    "# Training loss after the final step\n",
    "print(loss.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What do the activation regions look like?\n",
    "\n",
    "(an exercise in Tensor math)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Make a grid with 2*ns points along each axis\n",
    "ns = 25\n",
    "axis_pts = np.linspace(-1.5, 1.5, 2*ns)\n",
    "xx, yy = np.meshgrid(axis_pts, axis_pts)\n",
    "# Shape of each is [2*ns, 2*ns]\n",
    "\n",
    "# Combine into a single [2, 2*ns, 2*ns] tensor, then flatten it into a\n",
    "# [(2*ns)**2, 2] tensor: a sequence of x,y coordinate pairs\n",
    "G = torch.tensor(np.array([xx, yy]), dtype=torch.float).reshape(2, -1).t()\n",
    "\n",
    "# For each row (sample) in G, get the prediction under the model.\n",
    "# The variables inside the model are tracked for gradients, so call\n",
    "# \"detach()\" to stop tracking gradient for further computations.\n",
    "# Result is shape [(2*ns)**2, 2] since the model takes 2-dim vectors\n",
    "# and generates a 2-dim prediction.\n",
    "result = model(G).detach()\n",
    "\n",
    "# weights assigned to class 0 and to class 1\n",
    "c0 = result[:,0]\n",
    "c1 = result[:,1]\n",
    "\n",
    "grid_x = G[:,0].numpy()\n",
    "grid_y = G[:,1].numpy()\n",
    "\n",
    "# One hexbin plot per class.\n",
    "# Gridsize is half that of the meshgrid for clean rendering.\n",
    "for title, activation in ((\"Class 0 Activation\", c0), (\"Class 1 Activation\", c1)):\n",
    "    plt.hexbin(grid_x, grid_y, activation.numpy(), gridsize=ns, cmap='viridis')\n",
    "    plt.title(title)\n",
    "    plt.axis('equal')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What is the classification performance?\n",
    "\n",
    "Case study in working with Yellowbrick"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.base import BaseEstimator, ClassifierMixin\n",
    "\n",
    "class NetWrapper(ClassifierMixin, BaseEstimator):\n",
    "    \"\"\"\n",
    "    Wrap a trained torch model as a scikit-learn style classifier so that\n",
    "    Yellowbrick visualizers can score it.\n",
    "\n",
    "    The wrapped model is expected to map a float tensor of shape\n",
    "    [n_samples, n_features] to class probabilities of shape\n",
    "    [n_samples, n_classes], where column j belongs to the j-th smallest label.\n",
    "    \"\"\"\n",
    "    _estimator_type = \"classifier\"\n",
    "    # Tell yellowbrick this is a classifier\n",
    "    \n",
    "    def __init__(self, model):\n",
    "        # Save a reference to the model.\n",
    "        # classes_ is deliberately not created here: scikit-learn treats attributes\n",
    "        # with a trailing underscore as evidence that fit() has been called.\n",
    "        self.model = model\n",
    "    \n",
    "    def fit(self, X, y):\n",
    "        \"\"\"\n",
    "        Record the sorted, unique class labels found in y.\n",
    "\n",
    "        The torch model is already trained, so nothing is optimised here.\n",
    "        Returns self, following the scikit-learn convention.\n",
    "        \"\"\"\n",
    "        self.classes_ = np.unique(y)\n",
    "        return self\n",
    "    \n",
    "    def predict_proba(self, X):\n",
    "        \"\"\"\n",
    "        Define predict_proba or decision_function\n",
    "        \n",
    "        Compute predictions from the wrapped model (self.model, not a global).\n",
    "        Transform input into a Tensor, compute the prediction, \n",
    "        transform the prediction back into a numpy array\n",
    "        \"\"\"\n",
    "        inputs = torch.as_tensor(X, dtype=torch.float)\n",
    "        # no_grad: inference only, no autograd graph is needed\n",
    "        with torch.no_grad():\n",
    "            return self.model(inputs).numpy()\n",
    "    \n",
    "    def predict(self, X):\n",
    "        \"\"\"\n",
    "        Return the most probable class label for each row of X.\n",
    "        Together with ClassifierMixin this also provides score() (accuracy).\n",
    "        \"\"\"\n",
    "        return self.classes_[np.argmax(self.predict_proba(X), axis=1)]\n",
    "\n",
    "\n",
    "# Wrap the model\n",
    "wrapped_net = NetWrapper(model)\n",
    "\n",
    "# Use ROCAUC as per usual\n",
    "ROCAUC = yb.classifier.ROCAUC(wrapped_net)\n",
    "\n",
    "ROCAUC.fit(orig_X_train, orig_y_train)\n",
    "ROCAUC.score(orig_X_test, orig_y_test)\n",
    "ROCAUC.poof()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom Modules\n",
    "\n",
    "Implementing new functionality, e.g. radial activation regions for \"circular\" neurons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# weight: a * (x-c)^T(x-c), a is a real number\n",
    "\n",
    "class Circle(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Extend torch.nn.Module for a new \"layer\" in a neural network\n",
    "\n",
    "    The layer holds k \"circular\" neurons. Neuron i has a center C[i] and a\n",
    "    scale alpha[i] and outputs alpha[i] * (x - C[i])^T (x - C[i]), i.e. the\n",
    "    scaled squared distance between the input x and the center.\n",
    "    \"\"\"\n",
    "    def __init__(self, k, data):\n",
    "        \"\"\"\n",
    "        k is the number of neurons to use\n",
    "        data is a tensor of samples, one per row; k distinct rows are drawn\n",
    "        at random (np.random.choice) to initialize the centers\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        \n",
    "        # k is not a Parameter, so there is no gradient and this is not updated in optimization\n",
    "        self.k = int(k)\n",
    "        \n",
    "        # Parameters always have gradients computed\n",
    "        # alpha: [k, 1], drawn from a normal distribution with mean 0 and std 0.5\n",
    "        self.alpha = torch.nn.Parameter(torch.normal(mean=torch.zeros(self.k), std=torch.ones(self.k)*0.5).unsqueeze(1))\n",
    "        # C: [k, 1, n_features]; the middle axis broadcasts over the samples in forward()\n",
    "        rows = np.random.choice(data.shape[0], self.k, replace=False)\n",
    "        self.C = torch.nn.Parameter(data[rows, :].unsqueeze(1))\n",
    "        \n",
    "        \n",
    "    def forward(self, x): \n",
    "        \"\"\"\n",
    "        x has shape [n_samples, n_features]; the result has shape [n_samples, k].\n",
    "        \"\"\"\n",
    "        # Broadcast [n_samples, n_features] against [k, 1, n_features] -> [k, n_samples, n_features]\n",
    "        diff = (x - self.C)\n",
    "        # compact way of writing inner products, outer products, etc.\n",
    "        # Here: squared distance of every sample to every center, shape [k, n_samples]\n",
    "        sq_dist = torch.einsum('kij,kij->ki', [diff, diff])\n",
    "\n",
    "        # Scale per neuron, then put the sample axis first\n",
    "        return (self.alpha * sq_dist).transpose(0,1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "loss_fn = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# Same architecture as the basic net, with the first Linear layer replaced by 16\n",
    "# \"circular\" neurons whose centers are initialised from the training samples.\n",
    "model = nn.Sequential(\n",
    "          Circle(16, X),\n",
    "          nn.ReLU(),\n",
    "          nn.Linear(16,4),\n",
    "          nn.ReLU(),\n",
    "          nn.Linear(4,2),\n",
    "          nn.Softmax(dim=1)    # probabilities for plotting / predict_proba\n",
    "        )\n",
    "\n",
    "# CrossEntropyLoss applies log-softmax internally and expects raw logits, so\n",
    "# train on the model without its final Softmax layer. The slice reuses the same\n",
    "# layer objects: parameters are shared and model[0] is still the Circle layer.\n",
    "logits_net = model[:-1]\n",
    "\n",
    "optimizer = torch.optim.Adam(params=model.parameters())\n",
    "for i in tqdm(range(1000)):\n",
    "    optimizer.zero_grad()\n",
    "    output = logits_net(X)\n",
    "    loss = loss_fn(output, y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Evaluation grid over [-1.5, 1.5] with 2*ns points per axis\n",
    "ns = 25\n",
    "ticks = np.linspace(-1.5, 1.5, 2*ns)\n",
    "xx, yy = np.meshgrid(ticks, ticks)\n",
    "\n",
    "# Flatten the grid to one (x, y) row per grid point and evaluate the model on it\n",
    "G = torch.tensor(np.array([xx, yy]), dtype=torch.float).reshape(2, -1).t()\n",
    "result = model(G).detach()\n",
    "\n",
    "# Column j of the result is the model output for class j\n",
    "c0, c1 = result[:,0], result[:,1]\n",
    "grid_x, grid_y = G[:,0].numpy(), G[:,1].numpy()\n",
    "\n",
    "for title, activation in ((\"Class 0 Activation\", c0), (\"Class 1 Activation\", c1)):\n",
    "    plt.hexbin(grid_x, grid_y, activation.numpy(), gridsize=ns, cmap='viridis')\n",
    "    plt.title(title)\n",
    "    plt.axis('equal')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-wrap: `model` now refers to the network that starts with the Circle layer\n",
    "wrapped_net = NetWrapper(model)\n",
    "ROCAUC = yb.classifier.ROCAUC(wrapped_net)\n",
    "\n",
    "# fit() records the class labels; score() evaluates on the held-out test split\n",
    "ROCAUC.fit(orig_X_train, orig_y_train)\n",
    "ROCAUC.score(orig_X_test, orig_y_test)\n",
    "ROCAUC.poof()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Pull the learned parameters out of the Circle layer (first layer of the model):\n",
    "# one center and one scale factor per \"kernel\"\n",
    "circle_layer = model[0]\n",
    "centers = circle_layer.C.detach().squeeze().numpy()\n",
    "scales = circle_layer.alpha.detach().squeeze().numpy()\n",
    "\n",
    "# Overlay the centers on the (faded) training points\n",
    "plt.scatter(centers[:,0], centers[:,1])\n",
    "plt.scatter(X[:,0], X[:,1], alpha=0.1)\n",
    "plt.axis('equal')\n",
    "\n",
    "print(centers.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# Show the contours of the activation regions of each kernel\n",
    "\n",
    "ns = 25\n",
    "axis_ticks = np.linspace(-2, 2, ns)\n",
    "xx, yy = np.meshgrid(axis_ticks, axis_ticks)\n",
    "\n",
    "# Flatten the grid into an [ns*ns, 2] tensor of (x, y) points\n",
    "G = torch.tensor(np.array([xx, yy]), dtype=torch.float).reshape(2, -1).t()\n",
    "\n",
    "# Scaled squared distance from every grid point to every center.\n",
    "# Broadcasting [ns*ns, 2] against [k, 1, 2] gives [k, ns*ns, 2]; Z is [k, ns*ns].\n",
    "diff = G - torch.tensor(centers).unsqueeze(1)\n",
    "Z = (torch.tensor(scales).unsqueeze(1) * torch.einsum('kij,kij->ki', [diff, diff])).numpy()\n",
    "\n",
    "plt.scatter(centers[:,0], centers[:,1])\n",
    "plt.scatter(X[:,0], X[:,1], alpha=0.1)\n",
    "\n",
    "# pyplot's get_cmap is available on old and new Matplotlib releases,\n",
    "# whereas matplotlib.cm.get_cmap has been removed from recent ones.\n",
    "cmap = plt.get_cmap('tab20')\n",
    "for i in range(Z.shape[0]):\n",
    "    # Kernels with a positive scale are drawn dotted, the others as faint solid lines\n",
    "    if scales[i] > 0:\n",
    "        line_style = dict(alpha=0.8, linestyles='dotted')\n",
    "    else:\n",
    "        line_style = dict(alpha=0.3, linestyles='solid')\n",
    "    # Z[i] reshaped to (ns, ns) has rows along y and columns along x, as contour expects.\n",
    "    # The modulo keeps the colour index inside the colormap if there are more kernels than colours.\n",
    "    plt.contour(axis_ticks, axis_ticks, Z[i].reshape(ns, ns), [-0.5,0.5],\n",
    "                antialiased=True, colors=[cmap(i % cmap.N)], **line_style)\n",
    "\n",
    "plt.axis('equal');\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
